Switch Docker release to CUDA 11.7 (#94818) Switch Docker release to CUDA 11.7 Remove `ptxas` installation logic as Trition is now bundled with ptxas Successful run: https://github.com/pytorch/pytorch/actions/runs/4176843201/jobs/7233661196 Pull Request resolved: https://github.com/pytorch/pytorch/pull/94818 Approved by: https://github.com/malfet
diff --git a/Dockerfile b/Dockerfile index ce420dc..e6ade30 100644 --- a/Dockerfile +++ b/Dockerfile
@@ -60,7 +60,7 @@ FROM conda as conda-installs ARG PYTHON_VERSION=3.8 -ARG CUDA_VERSION=11.6 +ARG CUDA_VERSION=11.7 ARG CUDA_CHANNEL=nvidia ARG INSTALL_CHANNEL=pytorch-nightly # Automatically set by buildx @@ -68,7 +68,7 @@ RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION} ARG TARGETPLATFORM -# On arm64 we can only install wheel packages +# On arm64 we can only install wheel packages. RUN case ${TARGETPLATFORM} in \ "linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio torchtext ;; \ *) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchaudio torchtext "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ @@ -89,11 +89,6 @@ COPY --from=conda-installs /opt/conda /opt/conda RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then \ apt install -y --no-install-recommends gcc; \ - CU_VER=$(echo $CUDA_VERSION | cut -d'.' -f 1-2) && \ - mkdir -p /usr/local/triton-min-cuda-${CU_VER} && \ - ln -s /usr/local/triton-min-cuda-${CU_VER} /usr/local/cuda; \ - mkdir -p /usr/local/cuda/bin; cp /opt/conda/bin/ptxas /usr/local/cuda/bin; \ - mkdir -p /usr/local/cuda/include; cp /opt/conda/include/cuda.h /usr/local/cuda/include; \ fi RUN rm -rf /var/lib/apt/lists/* ENV PATH /opt/conda/bin:$PATH diff --git a/docker.Makefile b/docker.Makefile index f85a3c3..fd49964 100644 --- a/docker.Makefile +++ b/docker.Makefile
@@ -8,7 +8,7 @@ DOCKER_ORG = $(shell whoami) endif -CUDA_VERSION = 11.6.2 +CUDA_VERSION = 11.7.0 CUDNN_VERSION = 8 BASE_RUNTIME = ubuntu:18.04 BASE_DEVEL = nvidia/cuda:$(CUDA_VERSION)-cudnn$(CUDNN_VERSION)-devel-ubuntu18.04